import numpy as np
import matplotlib.pyplot as plt

class PDE1d:
    """Finite-volume time stepper for a 1-D density ``rho`` on a uniform grid.

    The right-hand side is assembled from three user-supplied callbacks
    (``H_prime_rho``, ``V`` and ``W``); a callback left as ``None`` contributes
    nothing. Face values are reconstructed with a minmod-limited slope and the
    state is advanced with a three-stage Runge-Kutta step (``update_RK``).
    """

    def __init__(self,dz,N,nT,H_prime_rho=None,V=None,W=None,save_data=True,Nsamples=100):
        """Build the grid and allocate storage.

        dz:          spacing between z coordinates for rho (must be > 0)
        N:           grid size parameter; the grid runs from -(N-20+0.5)*dz
                     to (N+20+0.5)*dz inclusive
        nT:          number of time steps for simulation
        H_prime_rho: diffusion callback, called as H_prime_rho(rho_face, rho_cell)
        V:           potential callback, called as V(x, z_i, Zsamp)
        W:           kernel callback, called as W(xj, xi, dz)
        save_data:   if True keep rho at every time step, otherwise keep only
                     the two snapshots listed in ``self.saveat``
        Nsamples:    number of samples drawn from the initial distribution

        Raises ValueError if dz is not positive.
        """
        if dz <= 0:
            raise ValueError("dz must be positive")

        x_start = -(N-20+0.5)*dz
        x_stop  = (N+20+0.5)*dz
        # Derive the node count from the span instead of padding the stop value
        # of a float-step np.arange: an absolute pad adds spurious nodes past
        # x_stop as soon as dz is not larger than the pad.
        n_points = int(np.floor((x_stop - x_start)/dz + 1e-9)) + 1
        grid     = x_start + dz*np.arange(n_points)

        self.dz       = dz
        self.nRho     = n_points
        self.z_i      = grid
        self.theta    = 2. # minmod parameter
        self.nT       = nT
        self.loss     = np.zeros(nT-1)
        self.nSamples = Nsamples

        self.H_prime_rho = H_prime_rho # diffusion function H_prime_rho(rho_east_west,rho_bar)
        self.V           = V # potential function V(x,z,samples)
        self.W           = W # kernel function W(xj,xi,dx)
        if save_data:
            self.rho_history = np.zeros((n_points,nT))
        else:
            self.rho_history = np.zeros((n_points,2)) # two snapshots only
            self.saveat      = [np.floor(nT/5.),np.floor(nT/3.)] # time steps of the snapshots
        self.save_data = save_data

    def set_initial_distribution(self,f,g=None):
        """Evaluate the initial densities on the grid and draw the samples.

        f: function describing the PDF of the initial distribution for the negative labels
        g: function describing the PDF of the initial distribution for the positive labels
           (optional; ``self.g0`` is None when g is not callable)

        Sets ``rho0``, ``rho``, ``g0`` and ``Zsamp`` (``nSamples`` grid points
        drawn with probability proportional to ``rho0``).

        Raises ValueError if f has no positive total mass on the grid.
        """
        self.rho0 = np.asarray(f(self.z_i), dtype=float)
        # Independent copy so that stepping rho never touches rho0.
        self.rho  = self.rho0.copy()
        self.g0   = g(self.z_i) if callable(g) else None

        total_mass = np.sum(self.rho0)
        if not total_mass > 0:
            raise ValueError("initial distribution must have positive total mass on the grid")
        self.Zsamp = np.random.choice(self.z_i,size=self.nSamples,p=self.rho0/total_mass)

    @staticmethod
    def _integrate(y,x):
        """Trapezoidal integral of y over x, working on both old and new NumPy."""
        trapezoid = getattr(np,"trapezoid",None) or np.trapz
        return trapezoid(y,x)

    def update_RK(self,x,t,dt,samples):
        """Advance rho by one three-stage Runge-Kutta step of size dt.

        x:       external influence forwarded to V and to the loss
        t:       index of the time step (column of rho_history, loss slot t-1)
        dt:      time step size
        samples: sample positions used by compute_loss

        Returns the new, renormalised density. A fresh array is created on
        every call, so arrays returned by earlier calls are never modified.
        """
        rho_old = self.rho
        k1      = self._F(rho_old,               x) # +0
        k2      = self._F(rho_old+dt/2*k1,       x) # +h/2
        k3      = self._F(rho_old-dt*k1+2*dt*k2, x) # +h
        rho_new = rho_old + dt/6.*(k1 + 4*k2 + k3)
        # Rescale so the trapezoidal integral of rho over the grid is one.
        rho_new = rho_new / self._integrate(rho_new,self.z_i)
        self.rho = rho_new
        self.compute_loss(samples,x,t)
        if self.save_data:
            self.rho_history[:,t] = rho_new
        elif t == self.saveat[0]:
            self.rho_history[:,0] = np.copy(rho_new)
        elif t == self.saveat[1]:
            self.rho_history[:,1] = np.copy(rho_new)
        return rho_new

    def compute_loss(self,samples,x,t):
        """Compute the cost of the negative labels and store it in loss[t-1].

        The loss is sum(log(1+exp(-3*(samples-x)))) plus the sum over all
        (grid point, sample) pairs of rho minus the sample, divided by the
        number of samples.

        NOTE(review): the original author was not sure the second term is
        correct; it is kept unchanged here.
        """
        samples          = np.asarray(samples)
        n_samples        = len(samples)
        loss_interaction = np.sum(np.log(1.+np.exp(-3.*(samples-x))))
        moving_terms     = self.rho.reshape((self.nRho,1)) - samples.reshape((1,n_samples))
        self.loss[t-1]   = loss_interaction + np.sum(moving_terms)/n_samples

    def _get_drho_dt(self,flux_list_):
        """Return -(F_{j+1/2} - F_{j-1/2})/dz with zero flux outside the grid."""
        padded = np.concatenate(([0],flux_list_,[0]))
        return -(padded[1:]-padded[:-1])/self.dz

    def _minmod(self,rho_bar_,rho_bar_plus_1_,rho_bar_minus_1_):
        """Return the minmod-limited slope rho_x_j for every cell.

        The slope is the candidate closest to zero among the theta-scaled
        forward difference, the centred difference and the theta-scaled
        backward difference when all three share a sign, and zero otherwise.
        The first and last cells use the theta-scaled one-sided difference.
        """
        theta_ = self.theta
        dz     = self.dz
        q1 = theta_ * (rho_bar_plus_1_-rho_bar_) / dz
        q2 = (rho_bar_plus_1_-rho_bar_minus_1_)/(2*dz)
        q3 = theta_ * (rho_bar_-rho_bar_minus_1_) / dz

        all_negative = (q1 < 0) & (q2 < 0) & (q3 < 0)
        all_positive = (q1 > 0) & (q2 > 0) & (q3 > 0)

        slopes = np.zeros(len(rho_bar_))
        slopes[all_negative] = np.maximum(np.maximum(q1,q2),q3)[all_negative]
        slopes[all_positive] = np.minimum(np.minimum(q1,q2),q3)[all_positive]
        # Boundary cells have only one real neighbour: use the one-sided slope.
        slopes[0]  = q1[0]
        slopes[-1] = q3[-1]

        return slopes

    def _get_minmod(self,rho_bar_):
        """Build the shifted neighbour arrays (zero padded) and apply _minmod."""
        right_neighbour = np.append(rho_bar_[1:],0)
        left_neighbour  = np.insert(rho_bar_[:-1],0,0)
        return self._minmod(rho_bar_,right_neighbour,left_neighbour)

    def _compute_F_and_u(self,xi_list_,rho_bar,rho_east_,rho_west_):
        """Return the upwinded flux at the interior cell faces.

        xi_list_:  cell values whose negative difference quotient is the velocity
        rho_bar:   cell averages
        rho_east_: reconstructed values at the right edge of each cell
        rho_west_: reconstructed values at the left edge of each cell

        The result has one entry per interior face (len(rho_bar)-1).
        """
        dz    = self.dz
        u     = -(xi_list_[1:]-xi_list_[:-1])/dz
        rho_e = rho_east_[:-1]
        rho_w = rho_west_[1:]
        from_left  = u*rho_e
        from_right = u*rho_w

        H = self.H_prime_rho
        if H is not None: # a missing diffusion callback contributes nothing
            from_left  -= H(rho_e,rho_bar[1:]) / dz - H(rho_e,rho_bar[:-1]) / dz
            from_right -= H(rho_w,rho_bar[1:]) / dz - H(rho_w,rho_bar[:-1]) / dz
        return np.maximum(from_left,0) + np.minimum(from_right,0)

    def _evaluate_W(self,rho_bar_list,W_symmetric=True):
        """Return sum_i W(x_j,x_i,dz)*rho_i for every grid point x_j.

        With W_symmetric=True the kernel is called once per x_j as
        W(0,|z_i-x_j|,dz) on the whole distance array; otherwise it is called
        for every pair (x_j,x_i). Returns zeros when no kernel was supplied.
        """
        nn      = len(rho_bar_list)
        xi_list = np.zeros(nn)
        if self.W is None:
            return xi_list

        dz   = self.dz
        grid = self.z_i
        if W_symmetric:
            # this is a lot faster - only 1 for loop
            for jj, xj in enumerate(grid[:nn]):
                dist = np.abs(grid-xj)
                xi_list[jj] = np.sum(self.W(0,dist,dz)*rho_bar_list)
        else:
            for jj, xj in enumerate(grid[:nn]):
                total_ = 0
                for xi, rho_bar_i in zip(grid[:nn],rho_bar_list):
                    total_ += self.W(xj,xi,dz) * rho_bar_i
                xi_list[jj] = total_

        return xi_list

    def _F(self,rho_bar,x):
        """Return F(rho) for RK method.

        x is the external influence passed to V. The first and last entries
        of the returned array are forced to zero.
        """
        dz = self.dz
        xi = self._evaluate_W(rho_bar) # we are using H_prime but can't use this discretization
        if self.V is not None: # a missing potential contributes nothing
            xi = self.V(x, self.z_i,self.Zsamp) + xi
        rho_x_   = self._get_minmod(rho_bar)
        rho_west = rho_bar - dz*rho_x_/2
        rho_east = rho_bar + dz*rho_x_/2
        flux     = self._compute_F_and_u(xi,rho_bar,rho_east,rho_west)
        drhodt   = self._get_drho_dt(flux)
        drhodt[-1] = 0
        drhodt[0]  = 0

        return drhodt

    # def _evaluate_loss(self,t):
    #     '''compute loss at current time step, save'''
    #     rho = self.rho
    #     pass
